import pandas as pd
import numpy as np
from scipy import stats
import networkx as nx
import matplotlib.pyplot as plt
import copy
import torch
import torchvision
from nltk.corpus import wordnet as wn
import torchvision.transforms as transforms
import torchvision.models
from torch.nn import functional as F
from torch.utils.data import DataLoader
from torch.utils.data.sampler import SubsetRandomSampler
import torchvision.transforms as tvtf
import tqdm
import csv
import pickle
import os
import timm
from sklearn.model_selection import train_test_split
from hierarchy.hierarchy import Hierarchy
from hierarchy.inference_rules import *
import clip
from utils import log_utils
from utils.clip_utils import ImageNetClip
from utils.csv_utils import *
import argparse
from timeit import default_timer as timer
from hierarchy.inference_rules import get_inference_rule
from hierarchical_selective_classification import get_hierarchy

# Module-level climbing inference rule, built once at import time on the
# hierarchy loaded from 'resources/imagenet1k_hier.pkl'.
climb_inf_rule = ClimbingInferenceRule(get_hierarchy(rebuild_hier=False, load_hier=True, path='resources/imagenet1k_hier.pkl'))

def optimal_threshold_algorithm(hierarchy, y_scores_cal, y_true_cal, alpha=0.1):
    """Calibrate the climbing-rule threshold on a calibration set.

    Args:
        hierarchy: hierarchy object; must provide ``all_nodes_probs``.
        y_scores_cal: calibration scores; the arg-max over dim 1 is used as
            the leaf prediction of each sample.
        y_true_cal: ground-truth labels of the calibration samples.
        alpha: forwarded to ``compute_quantile_threshold``.

    Returns:
        Whatever the 'Climbing' inference rule's ``compute_quantile_threshold``
        returns for the per-sample tight thresholds.
    """
    rule = get_inference_rule('Climbing', hierarchy)
    node_probs = hierarchy.all_nodes_probs(y_scores_cal)
    _, leaf_preds = y_scores_cal.max(dim=1)
    tight_thresholds = rule.get_tight_thresholds(node_probs, leaf_preds, y_true_cal)
    return rule.compute_quantile_threshold(tight_thresholds, alpha=alpha)

def DARTS(hierarchy, y_scores_cal, y_true_cal, epsilon=0.1, iteration_limit=25, confidence=0.95):
    """Binary-search the DARTS trade-off parameter lambda on a calibration set.

    Every node is scored as ``(reward + lambda) * probability`` and the
    highest-scoring node is predicted. The search looks for a small lambda
    whose calibration accuracy has a binomial lower confidence bound above
    ``1 - epsilon``.

    Args:
        hierarchy: hierarchy object; must provide ``num_leaves``,
            ``coverage_vec``, ``root_index``, ``all_nodes_probs`` and
            ``correctness``.
        y_scores_cal: calibration scores passed to ``hierarchy.all_nodes_probs``.
        y_true_cal: ground-truth labels of the calibration samples.
        epsilon: tolerated error rate; the target accuracy is ``1 - epsilon``.
        iteration_limit: number of binary-search iterations.
        confidence: one-sided confidence level of the accuracy lower bound
            (implemented as the lower end of a two-sided interval with mass
            ``1 - 2 * (1 - confidence)``).

    Returns:
        ``0`` if the lambda=0 predictor already reaches the target accuracy.
        Otherwise the smallest lambda visited whose lower bound exceeded the
        target, or lambda-bar if no visited lambda qualified.
    """
    # the reward for each node is: coverage * root entropy
    root_entropy = np.log2(hierarchy.num_leaves)
    rewards = hierarchy.coverage_vec * root_entropy
    # Step 1+2: get probabilities for all nodes and sum them upwards
    all_nodes_probs_cal = hierarchy.all_nodes_probs(y_scores_cal)

    def accuracy_at(node_scores):
        # Per-sample prediction is the arg-max over dim 0
        # (assumes nodes are laid out along dim 0 -- TODO confirm).
        preds = node_scores.max(dim=0)[1]
        correctness = hierarchy.correctness(preds, y_true_cal).cpu()
        return correctness.sum().item() / len(correctness), len(preds)

    # Step 3+4: init f_0, if its accuracy suffices then return it
    f_0_accuracy, num_examples = accuracy_at(rewards * all_nodes_probs_cal)
    if f_0_accuracy >= 1-epsilon:
        return 0
    # Step 5: calculate lambda bar
    r_max = rewards.max()
    r_root = rewards[hierarchy.root_index]
    lambda_bar = (r_max * (1-epsilon) - r_root) / epsilon
    # Step 6: binary search for optimal lambda
    min_lambda = 0
    max_lambda = lambda_bar.item()
    desired_alpha = (1 - confidence) * 2
    for _ in range(iteration_limit):
        lambda_t = (min_lambda + max_lambda) / 2
        f_t_accuracy, _ = accuracy_at((rewards + lambda_t) * all_nodes_probs_cal)
        acc_bounds = stats.binom.interval(1-desired_alpha, num_examples, f_t_accuracy)
        acc_lower_bound = acc_bounds[0] / num_examples
        if acc_lower_bound > 1-epsilon:
            # lambda_t is accurate enough: try a smaller lambda
            max_lambda = lambda_t
        else:
            min_lambda = lambda_t
    return max_lambda

class GraphDistanceLoss(torch.nn.Module):
    """Loss based on the hierarchy distance between a prediction and a label.

    For each sample the node ``lcas[pred, label]`` is looked up (presumably the
    lowest common ancestor of the two nodes -- confirm against Hierarchy), and
    the loss is ``dist_matrix[pred, that node]`` divided by
    ``hierarchy.root_height``.
    """

    def __init__(self, hierarchy):
        super().__init__()
        self.hierarchy = hierarchy
        # Keep the lookup table on the hierarchy's device.
        self.lcas = hierarchy.lcas.to(hierarchy.device)

    def forward(self, preds, labels):
        """Return the mean of the per-sample losses."""
        per_sample = self.loss_per_sample(preds, labels)
        return torch.mean(per_sample)

    def loss_per_sample(self, preds, labels):
        """Return one loss value per (prediction, label) pair."""
        label_idx = labels.int()
        ancestor_idx = self.lcas[preds, label_idx].int()
        dist_to_ancestor = self.hierarchy.dist_matrix[preds, ancestor_idx]
        return dist_to_ancestor / self.hierarchy.root_height
    
def conformal_risk_control(hierarchy, y_scores_cal, y_true_cal, alpha, B=1, loss='graph_distance'):
    """Calibrate a climbing threshold with conformal risk control.

    Scans the grid ``np.linspace(0, 1, 1001)`` in increasing order and returns
    the first threshold ``lam`` whose corrected empirical risk satisfies
    ``n/(n+1) * rhat(lam) + B/(n+1) <= alpha``.

    Args:
        hierarchy: hierarchy object; must provide ``all_nodes_probs`` (and
            ``correctness`` when ``loss == '01'``). The climbing inference
            rule is built from this hierarchy.
        y_scores_cal: calibration scores; the arg-max over dim 1 is used as
            the leaf prediction of each sample.
        y_true_cal: ground-truth labels of the calibration samples.
        alpha: target risk level.
        B: upper bound on the loss used in the finite-sample correction.
        loss: ``'graph_distance'`` (mean ``GraphDistanceLoss``) or ``'01'``
            (one minus mean hierarchical correctness).

    Returns:
        The first grid threshold meeting the bound, or the last grid point
        (1.0) if none does.

    Raises:
        ValueError: if ``loss`` is not one of the supported names.
    """
    if loss == 'graph_distance':
        graph_loss = GraphDistanceLoss(hierarchy)

        def empirical_risk(hier_preds):
            return graph_loss(hier_preds, y_true_cal)
    elif loss == '01':
        def empirical_risk(hier_preds):
            return 1 - hierarchy.correctness(hier_preds, y_true_cal).float().mean()
    else:
        raise ValueError(f"Unknown loss '{loss}'; expected 'graph_distance' or '01'")

    # Build the rule from the hierarchy that was passed in, so a non-default
    # hierarchy is not silently evaluated with a rule built on another one.
    inference_rule = get_inference_rule('Climbing', hierarchy)
    lambdas = np.linspace(0, 1, 1001)
    all_nodes_probs_cal = hierarchy.all_nodes_probs(y_scores_cal)
    preds_leaf_cal = y_scores_cal.max(dim=1)[1]
    n = y_true_cal.shape[0]
    for lam in lambdas:
        _, hier_preds = inference_rule.predict(all_nodes_probs_cal, preds_leaf_cal, lam)
        rhat = empirical_risk(hier_preds)
        if (n/(n+1)) * rhat + B/(n+1) <= alpha:
            # First threshold that satisfies the bound. Stepping back one grid
            # point would return a threshold that was just shown to violate it.
            return lam
    # No grid point met the bound: fall back to the largest threshold.
    return lambdas[-1]

def validation(alg, hierarchy, y_scores_val, y_true_val, opt_result):
    """Evaluate a calibrated parameter on a validation set.

    Args:
        alg: ``'DARTS'`` (``opt_result`` is the lambda added to the node
            rewards) or one of ``'OTA'``, ``'CRC'``, ``'CRC01'``
            (``opt_result`` is the threshold given to the climbing rule).
        hierarchy: hierarchy object; must provide ``all_nodes_probs``,
            ``correctness`` and ``coverage`` (plus ``num_leaves`` and
            ``coverage_vec`` for DARTS). The climbing inference rule is built
            from this hierarchy.
        y_scores_val: validation scores; the arg-max over dim 1 is used as
            the leaf prediction for the climbing rule.
        y_true_val: ground-truth labels of the validation samples.
        opt_result: the calibrated parameter to evaluate.

    Returns:
        dict with ``'hier_accuracy'`` (fraction of hierarchically correct
        predictions) and ``'coverage'`` (``hierarchy.coverage(preds)``).

    Raises:
        ValueError: if ``alg`` is not a supported algorithm name.
    """
    climbing_algs = ('OTA', 'CRC', 'CRC01')
    if alg != 'DARTS' and alg not in climbing_algs:
        raise ValueError(f"Unknown algorithm '{alg}'; expected 'DARTS', 'OTA', 'CRC' or 'CRC01'")

    all_nodes_probs_val = hierarchy.all_nodes_probs(y_scores_val)
    if alg == 'DARTS':
        root_entropy = np.log2(hierarchy.num_leaves)
        rewards = hierarchy.coverage_vec * root_entropy
        node_scores = (rewards + opt_result) * all_nodes_probs_val
        preds = node_scores.max(dim=0)[1]
    else:
        # Build the rule from the hierarchy that was passed in, so a
        # non-default hierarchy is not evaluated with a mismatched rule.
        inference_rule = get_inference_rule('Climbing', hierarchy)
        preds_leaf_val = y_scores_val.max(dim=1)[1]
        _, preds = inference_rule.predict(all_nodes_probs_val, preds_leaf_val, opt_result)

    hier_correctness = hierarchy.correctness(preds, y_true_val).cpu()
    return {
        'hier_accuracy': hier_correctness.sum().item() / len(hier_correctness),
        'coverage': hierarchy.coverage(preds),
    }

# Algorithm names for which compare_threshold_algorithms allocates per-alpha
# result lists and emits summary columns.
threshold_algs = ['OTA', 'DARTS', 'CRC', 'CRC01']
def _record_threshold_run(store, alg, hierarchy, val_parts, opt_param):
    """Append a calibrated parameter and its averaged validation metrics to `store`.

    `validation` is run on every (scores, labels) pair in `val_parts`; accuracy
    and coverage are averaged with equal weight per part.
    """
    store['Optimal Param'].append(opt_param)
    part_results = [validation(alg, hierarchy, scores, labels, opt_param)
                    for scores, labels in val_parts]
    num_parts = len(part_results)
    store['Accuracy'].append(sum(r['hier_accuracy'] for r in part_results) / num_parts)
    store['Coverage'].append(sum(r['coverage'] for r in part_results) / num_parts)


def compare_threshold_algorithms(all_y_scores, all_y_true, model_name, temp_scaling):
    """Compare OTA, DARTS, CRC and CRC01 over repeated calibration/validation splits.

    For each target error level alpha, the data is split 100 times into a
    stratified calibration set of 10000 samples and a validation set. Every
    algorithm is calibrated on the former and evaluated on the latter.

    Args:
        all_y_scores: scores for all samples (indexed along dim 0).
        all_y_true: labels for all samples; also used for stratification.
        model_name: model identifier; if it contains ``'inat'`` the hierarchy
            at 'resources/inat21.pkl' is used and validation is run in two
            halves, otherwise 'resources/imagenet1k_hier.pkl' is used.
        temp_scaling: if truthy, ``'_TS'`` is appended to the reported name.

    Returns:
        dict with an ``'Architecture'`` entry plus, per algorithm and alpha,
        mean/std summaries of the calibrated parameter, accuracy (in %),
        accuracy error relative to ``100*(1-alpha)``, and coverage (in %).
    """
    print(f'Running threshold algorithms for model: {model_name}')
    alpha_vals = [0.005, 0.01, 0.05, 0.1, 0.15, 0.2, 0.3]
    n = 10000
    n_repeats = 100
    inat = 'inat' in model_name
    path = 'resources/imagenet1k_hier.pkl' if not inat else 'resources/inat21.pkl'
    hierarchy = get_hierarchy(rebuild_hier=False, load_hier=True, path=path)
    TS = '' if not temp_scaling else '_TS'
    results = {'Architecture': model_name+TS}
    for alpha in alpha_vals:
        alpha_results = {a:{'Optimal Param': [], 'Accuracy':[], 'Coverage':[]} for a in threshold_algs}
        print(f'alpha: {alpha}')
        for rep in range(n_repeats):
            # each rep produces a random calibration set; seeding with the rep
            # index makes the split reproducible across runs
            cal_indices, val_indices = train_test_split(
                np.arange(len(all_y_true)), train_size=n,
                stratify=all_y_true.cpu(), random_state=rep)
            y_scores_cal = all_y_scores[cal_indices].cuda()
            y_scores_val = all_y_scores[val_indices].cuda()
            y_true_cal = all_y_true[cal_indices].long().cuda()
            y_true_val = all_y_true[val_indices].long().cuda()
            if inat:
                # split the val set to 2 parts; slicing always yields exactly
                # two parts, also when the validation size is odd
                half = y_true_val.shape[0] // 2
                val_parts = [(y_scores_val[:half], y_true_val[:half]),
                             (y_scores_val[half:], y_true_val[half:])]
            else:
                val_parts = [(y_scores_val, y_true_val)]

            # OTA
            opt_theta = optimal_threshold_algorithm(hierarchy, y_scores_cal, y_true_cal, alpha=alpha)
            _record_threshold_run(alpha_results['OTA'], 'OTA', hierarchy, val_parts, opt_theta)

            # DARTS
            darts_lambda = DARTS(hierarchy, y_scores_cal, y_true_cal, epsilon=alpha)
            _record_threshold_run(alpha_results['DARTS'], 'DARTS', hierarchy, val_parts, darts_lambda)

            # CRC (graph-distance loss), validated with its own threshold
            crc_lambda = conformal_risk_control(hierarchy, y_scores_cal, y_true_cal, alpha=alpha)
            _record_threshold_run(alpha_results['CRC'], 'CRC', hierarchy, val_parts, crc_lambda)

            # CRC01 (0-1 loss); validated through the same climbing-rule path as CRC
            crc01_lambda = conformal_risk_control(hierarchy, y_scores_cal, y_true_cal, alpha=alpha, loss='01')
            _record_threshold_run(alpha_results['CRC01'], 'CRC', hierarchy, val_parts, crc01_lambda)

        target = 100*(1-alpha)
        for a in threshold_algs:
            prefix = f'{a}_{100*(1-alpha)}'
            accuracies = [100*r for r in alpha_results[a]['Accuracy']]
            errors = [target-acc for acc in accuracies]
            coverages = [100*r for r in alpha_results[a]['Coverage']]
            results[f'{prefix} Optimal Param Result (mean)'] = np.mean(alpha_results[a]['Optimal Param'])
            results[f'{prefix} Accuracy (mean)'] = np.mean(accuracies)
            results[f'{prefix} Accuracy (std)'] = np.std(accuracies)
            # Signed error is accuracy minus target (positive means above target).
            results[f'{prefix} Accuracy Error (mean)'] = np.mean([100*(r-(1-alpha)) for r in alpha_results[a]['Accuracy']])
            results[f'{prefix} Accuracy Error (std)'] = np.std(errors)
            results[f'{prefix} Accuracy Error Abs (mean)'] = np.mean(np.abs(errors))
            results[f'{prefix} Accuracy Error Abs (std)'] = np.std(np.abs(errors))
            results[f'{prefix} Coverage (mean)'] = np.mean(coverages)
            results[f'{prefix} Coverage (std)'] = np.std(coverages)

    return results
    
